from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# model_id = './model_save/charent/ChatLM-mini-Chinese'  # or 'charent/ChatLM-mini-Chinese' to load from the hub

# If Hugging Face is unreachable, uncomment the two lines below to download the model files
# from ModelScope instead; they will be saved under './model_save'.
# from modelscope import snapshot_download
# model_id = snapshot_download(model_id, cache_dir='./model_save')



class ChatLMMiniChinese:
  """Thin wrapper around the charent/ChatLM-mini-Chinese seq2seq chat model."""

  def __init__(self) -> None:
    # Local path to the downloaded model files; 'charent/ChatLM-mini-Chinese' also works to load from the hub.
    model_id = 'D:/coding/toy/wuta_admin/test/model_save/charent/ChatLM-mini-Chinese'
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.tokenizer = AutoTokenizer.from_pretrained(model_id)
    # trust_remote_code=True is required: the custom generation method (my_generate) lives in the model repo's code.
    self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id, trust_remote_code=True).to(self.device)

  def question(self, text):
    """Generate an answer for a single question string."""
    # The tokenizer returns plain Python lists here, so convert them to LongTensors explicitly.
    encode_ids = self.tokenizer([text])
    input_ids = torch.LongTensor(encode_ids['input_ids'])
    attention_mask = torch.LongTensor(encode_ids['attention_mask'])

    # my_generate is provided by the model's custom code: beam search, up to 256 generated tokens.
    outs = self.model.my_generate(
      input_ids=input_ids.to(self.device),
      attention_mask=attention_mask.to(self.device),
      max_seq_len=256,
      search_type='beam',
    )
    # batch_decode returns a list with one decoded string per input (a single element here).
    outs_txt = self.tokenizer.batch_decode(outs.cpu().numpy(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return outs_txt
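

# Minimal usage sketch: instantiate the wrapper once, ask a question, and print the first
# decoded answer. question() returns a list, so index into it for the string.
if __name__ == '__main__':
  bot = ChatLMMiniChinese()
  txt = '南京有什么好玩的地方吗？'  # "Are there any fun places to visit in Nanjing?"
  answers = bot.question(txt)
  print(answers[0])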